Inference | Hybrid prefix caching. #3225
Open
lmcafee-nvidia wants to merge 13 commits into NVIDIA:main
Conversation
Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com>
Bug fixes:

1. conv_state save-before-read: extract initial conv states BEFORE causal_conv1d_varlen_states + tensor_masked_update overwrite the conv_state buffer. Previously, initial_conv_states was read AFTER the buffer was updated, so restored requests would see their own newly computed final states instead of the pre-existing initial states, corrupting the convolution output.
2. cu_chunk_seqlens OOB: the SSM Triton kernels allocate per-chunk output arrays of size chunk_size (128). Passing cu_seqlens directly as cu_chunk_seqlens caused out-of-bounds memory access when any sequence exceeded chunk_size tokens. Fix: subdivide each sequence into chunks of at most self.chunk_size, producing correct cu_chunk_seqlens boundaries.
3. zxBCdt padding mismatch: after conv1d, the per-request loop rebuilt xBC with only real tokens while dt and z retained the padded token count, causing a shape assertion failure in the SSM kernel. Fix: strip padded tokens from zxBCdt before _ssm_prefill, then pad the output back to the original padded size for the downstream residual add.
4. Per-request conv1d with initial_states: causal_conv1d_fn cannot accept both seq_idx and initial_states simultaneously. The old code passed seq_idx to handle multiple sequences, but this zeroes the state at sequence boundaries instead of using the cached initial states. Fix: loop over requests, calling causal_conv1d_fn per request with initial_states and a channels-last layout.
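The chunk subdivision in fix 2 can be sketched as follows. This is a minimal standalone sketch assuming Python lists; `build_cu_chunk_seqlens` is a hypothetical name, and the real code operates on device tensors inside the mixer.

```python
def build_cu_chunk_seqlens(cu_seqlens, chunk_size):
    """Subdivide each sequence into chunks of at most `chunk_size` tokens.

    `cu_seqlens` is the usual cumulative-length vector [0, len0, len0+len1, ...].
    Returns cumulative chunk boundaries so that no chunk exceeds `chunk_size`,
    and no chunk crosses a sequence boundary.
    """
    boundaries = [0]
    for start, end in zip(cu_seqlens, cu_seqlens[1:]):
        pos = start
        while pos < end:
            pos = min(pos + chunk_size, end)
            boundaries.append(pos)
    return boundaries

# Two sequences of lengths 300 and 100 with chunk_size=128: the 300-token
# sequence is split 128/128/44, the 100-token sequence fits in one chunk.
print(build_cu_chunk_seqlens([0, 300, 400], 128))
# [0, 128, 256, 300, 400]
```

Passing `cu_seqlens` itself here would produce a single 300-token "chunk", which overflows the kernel's fixed chunk_size-sized output arrays.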
Improvements:

- Unify all Mamba prefill (including chunked) through a single varlen SSM kernel call, removing the separate chunked-prefill routing and the _batch_indices_chunked_prefill / _device_chunked_prefill metadata
- Simplify _dynamic_inference to a flat decode + prefill structure
- Add a _dynamic_inference_prefill helper that strips CUDA-graph padding from metadata and data tensors before calling _ssm_prefill
- Remove deprecated constructor parameters (use_mem_eff_path, d_state, headdim, ngroups) and their warnings
- Add an assertion format string in ssd_combined.py for easier debugging

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…raction

With chunk-aligned sequences (one sequence per chunk boundary), the final SSM state for each sequence is simply states[last_chunk_indices], making the separate chunk_state_varlen Triton kernel unnecessary. Construct last_chunk_indices in mamba_mixer.py alongside cu_chunk_seqlens, and remove the cu_seqlens parameter from the varlen API since it was only needed by chunk_state_varlen.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
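The indexing trick can be sketched in plain Python. `last_chunk_indices` here is a hypothetical standalone helper over lists; the real code derives it from tensors in mamba_mixer.py.

```python
def last_chunk_indices(cu_seqlens, cu_chunk_seqlens):
    """For each sequence, the index of its final chunk in the per-chunk
    state array, i.e. final SSM state = states[last_chunk_indices].

    Assumes every entry of `cu_seqlens` also appears in `cu_chunk_seqlens`
    (chunking never crosses a sequence boundary). Chunk i covers tokens
    [cu_chunk_seqlens[i], cu_chunk_seqlens[i+1]).
    """
    return [cu_chunk_seqlens.index(end) - 1 for end in cu_seqlens[1:]]

# Sequences of lengths 300 and 100, chunked at 128: chunks 0-2 belong to
# sequence 0, chunk 3 to sequence 1.
print(last_chunk_indices([0, 300, 400], [0, 128, 256, 300, 400]))
# [2, 3]
```

Because each sequence ends exactly at a chunk boundary, no partial-chunk state reconstruction (the job of chunk_state_varlen) is needed; the final state is a pure gather.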
For prefix caching of Mamba layers, extract SSM and conv states at block-aligned chunk boundaries during varlen prefill. Since block_size_tokens % chunk_size == 0, every block boundary falls on a chunk boundary, making intermediate SSM state extraction pure indexing with no extra computation. Conv states are sliced from the pre-conv input tensor.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
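The block-to-chunk mapping can be illustrated with a small sketch. `block_state_chunk_indices` is a hypothetical helper name; it shows why the divisibility assumption makes intermediate state extraction free.

```python
def block_state_chunk_indices(seq_len, block_size, chunk_size):
    """Per-sequence (local) chunk indices whose end coincides with a block
    boundary. Because block_size % chunk_size == 0, every block boundary is
    also a chunk boundary, so the intermediate SSM state for each full block
    is pure indexing into the per-chunk state array.
    """
    assert block_size % chunk_size == 0
    return [off // chunk_size - 1
            for off in range(block_size, seq_len + 1, block_size)]

# block_size=256, chunk_size=128, seq_len=600: block boundaries fall at
# tokens 256 and 512, i.e. at the ends of local chunks 1 and 3.
print(block_state_chunk_indices(600, 256, 128))
# [1, 3]
```

If block_size were not a multiple of chunk_size, a block boundary could fall mid-chunk and the state there would have to be recomputed rather than gathered.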
Instead of breaking prefill into chunks at Mamba state boundaries, process the full prefill in one kernel call and extract SSM/conv states at specified token offsets. This eliminates the dependency on chunked-prefill scheduling and simplifies the engine.

Key changes:

- Two-map hash design: kv_hash_to_block_id + mamba_hash_to_block_id
- Mamba cache infrastructure: GPU memory pool for SSM/conv states
- Coupled prefix matching: skip tokens limited by the Mamba match count
- Intermediate offset computation at KV divergence and last-aligned
- Engine passes the Mamba match count and commits states after the forward pass
- KV eviction automatically invalidates Mamba state
- 18 new tests covering all Mamba caching paths

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
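The coupled prefix matching can be sketched as follows. This is a hypothetical simplification (`coupled_prefix_match`, plain dicts for the two hash maps, string hashes) of the two-map design described above, not the engine's actual data structures.

```python
def coupled_prefix_match(block_hashes, kv_hash_to_block_id,
                         mamba_hash_to_block_id, block_size):
    """Two-map prefix matching: KV blocks can only be reused up to the
    point where a cached Mamba state also exists, because a Mamba layer
    cannot resume mid-sequence without its recurrent state.

    `block_hashes` are the request's per-block prefix hashes; returns the
    number of prompt tokens that can be skipped.
    """
    # Longest contiguous KV-block prefix match.
    kv_matched = 0
    for h in block_hashes:
        if h not in kv_hash_to_block_id:
            break
        kv_matched += 1
    # Within that prefix, longest run that also has a cached Mamba state.
    mamba_matched = 0
    for h in block_hashes[:kv_matched]:
        if h not in mamba_hash_to_block_id:
            break
        mamba_matched += 1
    # Skip tokens are limited by the Mamba match count.
    return min(kv_matched, mamba_matched) * block_size

kv = {"a": 0, "b": 1, "c": 2}      # 3 KV blocks cached
mamba = {"a": 10, "b": 11}         # only 2 Mamba states cached
print(coupled_prefix_match(["a", "b", "c"], kv, mamba, 256))
# 512 -- KV matches 3 blocks, but only 2 blocks' Mamba states exist
```

This is also why KV eviction can safely invalidate Mamba state: a block absent from the KV map is never consulted in the Mamba map.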
What does this PR do?
Add Mamba state prefix caching for hybrid Transformer-Mamba models, enabling KV cache prefix sharing to also share corresponding Mamba conv/SSM states.
Key Features
`--inference-dynamic-batching-prefix-caching-mamba-gb` argument controls the memory budget for cached Mamba states

Changes
Core Implementation (megatron/core/inference/contexts/dynamic_context.py, megatron/core/inference/engines/dynamic_engine.py)

Block Allocator (megatron/core/inference/contexts/dynamic_block_allocator.py)

Arguments (megatron/training/arguments.py): `--inference-dynamic-batching-prefix-caching-mamba-gb` parameter

Test plan
tests/unit_tests/inference/contexts/test_mamba_prefix_caching.py

Contribution process
```mermaid
flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph "Code Review/Approval"
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]
```

Pre-checks
Core 0.8)Code review
The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch
Feel free to message or comment @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!
(Step 1): Add the Expert Review PR label

(Step 2): Collect the expert reviewers' reviews

Add the Expert Review label when your PR is ready for review. Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

Add the Final Review label.

(Optional Step 4): Cherry-pick into release branch
If this PR also needs to be merged into core_r* release branches, then after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for the `dev` branch is under active discussion. MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR
Any member of core-adlr and core-nemo will be able to merge your PR.